{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一.联合概率分布\n",
    "前两节我们分别介绍了线性回归的贝叶斯估计以及它的一种近似计算方法：证据近似，贝叶斯估计的图模型结构如下  \n",
    "![avatar](./source/15_VI_贝叶斯线性回归图模型.png)\n",
    "为了方便，我们将$\\beta$看做一个常数，回忆一下上一节，我们的似然函数和共轭先验如下：   \n",
    "\n",
    "$$\n",
    "p(t\\mid w)=\\prod_{n=1}^N N(t_n\\mid w^T\\phi(x_n),\\beta^{-1})\\\\\n",
    "p(w\\mid \\alpha)=N(w\\mid 0,\\alpha^{-1}I)\n",
    "$$  \n",
    "这里，$X=\\{x_1,x_2,...,x_N\\}$，$I$为单位矩阵，与上一节给$\\alpha$取定值不同，我们这里将$\\alpha$看做来源于一个Gamma分布，即   \n",
    "$$\n",
    "p(\\alpha)=Gam(\\alpha\\mid a_0,b_0)\n",
    "$$   \n",
    "所以，所有变量上的联合概率分布就有了   \n",
    "\n",
    "$$\n",
    "p(t,w,\\alpha)=p(t\\mid w)p(w\\mid\\alpha)p(\\alpha)\n",
    "$$"
   ]
  },
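  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked step, the joint distribution above can be expanded in log form, which is the expression the variational updates take expectations of. This is a sketch that assumes $w$ has dimension $M$ and uses the standard Gamma density $Gam(\\alpha\\mid a,b)=\\frac{b^a}{\\Gamma(a)}\\alpha^{a-1}e^{-b\\alpha}$; every term that involves neither $w$ nor $\\alpha$ is collected into $const$:   \n",
    "\n",
    "$$\n",
    "ln\\ p(t,w,\\alpha)=ln\\ p(t\\mid w)+ln\\ p(w\\mid\\alpha)+ln\\ p(\\alpha)\\\\\n",
    "=-\\frac{\\beta}{2}\\sum_{n=1}^N\\left \\{t_n-w^T\\phi(x_n)\\right \\}^2+\\frac{M}{2}ln\\ \\alpha-\\frac{\\alpha}{2}w^Tw+(a_0-1)ln\\ \\alpha-b_0\\alpha+const\n",
    "$$  \n",
    "\n",
    "Only the last four terms depend on $\\alpha$, and only the first and third depend on $w$, so each factor's update involves just those terms."
   ]
  },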
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二.变分分布\n",
    "接下来，我们就是要构建一个变分分布$q(w,\\alpha)$去近似后验概率分布$p(w,\\alpha\\mid t)$，利用平均场假设，令：   \n",
    "\n",
    "$$\n",
    "q(w,\\alpha)=q(w)q(\\alpha)\n",
    "$$  \n",
    "\n",
    "根据VI第一节最后推导的公式，我们可以直接写出最优解：   \n",
    "\n",
    "$$\n",
    "ln\\ q^*(\\alpha)=ln\\ p(\\alpha)+E_w[ln\\ p(w\\mid\\alpha)]+const\\\\\n",
    "=(a_0-1)ln\\ \\alpha-b_0\\alpha+\\frac{M}{2}ln\\ \\alpha-\\frac{\\alpha}{2}E[w^Tw]+const\n",
    "$$  \n",
    "\n",
    "可以发现这个Gamma分布取对数的形式，通过观察系数项，我们有  \n",
    "\n",
    "$$\n",
    "q^*(\\alpha)=Gam(a\\mid a_N,b_N)\n",
    "$$  \n",
    "\n",
    "其中  \n",
    "\n",
    "$$\n",
    "a_N=a_0+\\frac{M}{2}\\\\\n",
    "b_N=b_0+\\frac{1}{2}E[w^Tw]\n",
    "$$  \n",
    "\n",
    "同样的方式，我们可以求解$q^*(w)$:   \n",
    "\n",
    "$$\n",
    "ln\\ q^*(w)=ln\\ p(t\\mid w)+E_\\alpha[ln\\ p(w\\mid \\alpha)]+const\\\\\n",
    "=-\\frac{\\beta}{2}\\sum_{n=1}^N\\left \\{w^T\\phi(x_n)-t_n\\right \\}^2-\\frac{1}{2}E[\\alpha]w^Tw+const\\\\\n",
    "=-\\frac{1}{2}w^T(E[\\alpha]I+\\beta\\Phi^T\\Phi)w+\\beta w^T\\Phi^Tt+const\n",
    "$$  \n",
    "\n",
    "可以发现这是关于$w$的二次型，所以$q^*(w)$是一个高斯分布，通过配方可以得到：   \n",
    "\n",
    "$$\n",
    "q^*(w)=N(w\\mid m_N,S_N)\n",
    "$$  \n",
    "\n",
    "其中   \n",
    "\n",
    "$$\n",
    "m_N=\\beta S_N\\Phi^Tt\\\\\n",
    "S_N=(E[\\alpha]I+\\beta\\Phi^T\\Phi)^{-1}\n",
    "$$  \n",
    "\n",
    "细心的同学已经发现了，把$E[\\alpha]$换成$\\alpha$就是上一节推导出来的结果"
   ]
  },
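  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The completing-the-square step can be spelled out as follows. This is a sketch that relies only on the standard expansion of a Gaussian log density and on $S_N$ being symmetric. For $N(w\\mid m_N,S_N)$ we have:   \n",
    "\n",
    "$$\n",
    "ln\\ N(w\\mid m_N,S_N)=-\\frac{1}{2}(w-m_N)^TS_N^{-1}(w-m_N)+const\\\\\n",
    "=-\\frac{1}{2}w^TS_N^{-1}w+w^TS_N^{-1}m_N+const\n",
    "$$  \n",
    "\n",
    "Matching the quadratic term against $ln\\ q^*(w)$ gives $S_N^{-1}=E[\\alpha]I+\\beta\\Phi^T\\Phi$, and matching the linear term gives $S_N^{-1}m_N=\\beta\\Phi^Tt$, i.e. $m_N=\\beta S_N\\Phi^Tt$."
   ]
  },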
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三.迭代求解\n",
    "\n",
    "可以发现$q^*(w)$与$q^*(\\alpha)$之间是相互耦合的，$q^*(w)$中需要用到$E[\\alpha]$，而$q^*(\\alpha)$中需要用到$E[w^Tw]$，而根据Gamma分布和高斯分布的性质，我们可以很方便的写出这两个期望的求解公式   \n",
    "\n",
    "$$\n",
    "E[\\alpha]=\\frac{a_N}{b_N}\\\\\n",
    "E[w^Tw]=m_N^Tm_N+Tr(S_N)\n",
    "$$  \n",
    "\n",
    "那么，我们初始随意为$E[\\alpha]$赋予一个大于0的值，就可以迭代起来了：$E_0[\\alpha]\\rightarrow E_0[w^Tw]\\rightarrow E_1[\\alpha]\\rightarrow E_1[w^Tw]\\rightarrow \\cdots$"
   ]
  },
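  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both expectations follow from standard moment identities, spelled out here as a short worked step. For $q^*(w)=N(w\\mid m_N,S_N)$, using $w^Tw=Tr(ww^T)$ and $E[ww^T]=m_Nm_N^T+S_N$:   \n",
    "\n",
    "$$\n",
    "E[w^Tw]=Tr(E[ww^T])=Tr(m_Nm_N^T+S_N)=m_N^Tm_N+Tr(S_N)\n",
    "$$  \n",
    "\n",
    "As an illustrative special case (the code in this notebook states that it assumes $a_0=b_0=0$), the update for $E[\\alpha]$ reduces to:   \n",
    "\n",
    "$$\n",
    "E[\\alpha]=\\frac{a_N}{b_N}=\\frac{M/2}{\\frac{1}{2}E[w^Tw]}=\\frac{M}{m_N^Tm_N+Tr(S_N)}\n",
    "$$  \n",
    "\n",
    "Here $M$ is the dimension of $w$."
   ]
  },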
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 四.应用：预测分布\n",
    "对于新的样本点$\\hat{x}$，我们需要给出预测$\\hat{t}$，其概率预测同样使用上一节的贝叶斯积分，只不过将后验分布替换成变分分布   \n",
    "\n",
    "$$\n",
    "p(\\hat{t}\\mid \\hat{x},t)=\\int p(\\hat{t}\\mid\\hat{x},w)p(w\\mid t)dw\\\\\n",
    "\\simeq \\int p(\\hat{t}\\mid\\hat{x},w)q(w)dw\\\\\n",
    "=\\int N(\\hat{t}\\mid w^T\\phi(\\hat{x}),\\beta^{-1})N(w\\mid m_N,S_N)dw\\\\\n",
    "=N(\\hat{t}\\mid m_N^T\\phi(\\hat{x}),\\sigma^2(\\hat{x}))\n",
    "$$  \n",
    "\n",
    "这里$\\sigma^2(\\hat{x})=\\frac{1}{\\beta}+\\phi(\\hat{x})^TS_N\\phi(\\hat{x})$"
   ]
  },
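  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last equality above can be checked without evaluating the integral. As a sketch, write the model as $\\hat{t}=w^T\\phi(\\hat{x})+\\epsilon$, where the noise symbol $\\epsilon\\sim N(0,\\beta^{-1})$ is introduced here for illustration and is independent of $w\\sim N(m_N,S_N)$. Then $\\hat{t}$ is a linear combination of Gaussian variables, hence Gaussian, with:   \n",
    "\n",
    "$$\n",
    "E[\\hat{t}]=E[w]^T\\phi(\\hat{x})=m_N^T\\phi(\\hat{x})\\\\\n",
    "Var[\\hat{t}]=\\phi(\\hat{x})^TS_N\\phi(\\hat{x})+\\beta^{-1}\n",
    "$$  \n",
    "\n",
    "The first variance term reflects the remaining uncertainty in $w$, and the second is the observation noise."
   ]
  },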
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### 五.代码实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "线性回归的vi实现，代码封装到ml_models.vi中\n",
    "\"\"\"\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "class LinearRegression(object):\n",
    "    def __init__(self, basis_func=None, beta=1e-12, tol=1e-5, epochs=100):\n",
    "        \"\"\"\n",
    "        :param basis_func: list,基函数列表，包括rbf,sigmoid,poly_{num},linear，默认None为linear，其中poly_{num}中的{num}表示多项式的最高阶数\n",
    "        :param beta: 生成t标签的高斯噪声，这里可以设置低一点\n",
    "        :param tol:  两次迭代参数平均绝对值变化小于tol则停止\n",
    "        :param epochs: 默认迭代次数\n",
    "        \"\"\"\n",
    "        if basis_func is None:\n",
    "            self.basis_func = ['linear']\n",
    "        else:\n",
    "            self.basis_func = basis_func\n",
    "        self.beta = beta\n",
    "        self.tol = tol\n",
    "        self.epochs = epochs\n",
    "        # 特征均值、标准差\n",
    "        self.feature_mean = None\n",
    "        self.feature_std = None\n",
    "        # 训练参数，也就是m_N\n",
    "        self.w = None\n",
    "\n",
    "    def _map_basis(self, X):\n",
    "        \"\"\"\n",
    "        将X进行基函数映射\n",
    "        :param X:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        x_list = []\n",
    "        for basis_func in self.basis_func:\n",
    "            if basis_func == \"linear\":\n",
    "                x_list.append(X)\n",
    "            elif basis_func == \"rbf\":\n",
    "                x_list.append(np.exp(-0.5 * X * X))\n",
    "            elif basis_func == \"sigmoid\":\n",
    "                x_list.append(1 / (1 + np.exp(-1 * X)))\n",
    "            elif basis_func.startswith(\"poly\"):\n",
    "                p = int(basis_func.split(\"_\")[1])\n",
    "                for pow in range(1, p + 1):\n",
    "                    x_list.append(np.power(X, pow))\n",
    "        return np.concatenate(x_list, axis=1)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        self.feature_mean = np.mean(X, axis=0)\n",
    "        self.feature_std = np.std(X, axis=0) + 1e-8\n",
    "        X_ = (X - self.feature_mean) / self.feature_std\n",
    "        X_ = self._map_basis(X_)\n",
    "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n",
    "        n_sample, n_feature = X_.shape\n",
    "\n",
    "        E_alpha = self.beta  # 初始就和beta设置一样，让它自动去调节（这里设置任意大于0的值都是可以的）\n",
    "        current_w = None\n",
    "        for _ in range(0, self.epochs):\n",
    "            S_N = np.linalg.inv(E_alpha * np.eye(n_feature) + self.beta * X_.T @ X_)\n",
    "            self.w = self.beta * S_N @ X_.T @ y.reshape((-1, 1))  # 即m_N\n",
    "            if current_w is not None and np.mean(np.abs(current_w - self.w)) < self.tol:\n",
    "                break\n",
    "            current_w = self.w\n",
    "            E_w = (self.w.T @ self.w)[0][0] + np.trace(S_N)\n",
    "            E_alpha = (n_feature - 1) / E_w  # 这里假设a_0,b_0都为0\n",
    "        #ml_models包中不会return,这里用作分析\n",
    "        return E_alpha,self.beta\n",
    "\n",
    "    def predict(self, X):\n",
    "        X_ = (X - self.feature_mean) / self.feature_std\n",
    "        X_ = self._map_basis(X_)\n",
    "        X_ = np.c_[np.ones(X_.shape[0]), X_]\n",
    "        return (self.w.T @ X_.T).reshape(-1)\n",
    "\n",
    "    def plot_fit_boundary(self, x, y):\n",
    "        \"\"\"\n",
    "        绘制拟合结果\n",
    "        :param x:\n",
    "        :param y:\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        plt.scatter(x[:, 0], y)\n",
    "        plt.plot(x[:, 0], self.predict(x), 'r')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 六.测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#造伪样本\n",
    "X=np.linspace(0,100,100)\n",
    "X=np.c_[X,np.ones(100)]\n",
    "w=np.asarray([3,2])\n",
    "Y=X.dot(w)\n",
    "X=X.astype('float')\n",
    "Y=Y.astype('float')\n",
    "X[:,0]+=np.random.normal(size=(X[:,0].shape))*3#添加噪声\n",
    "Y=Y.reshape(100,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#加噪声\n",
    "X=np.concatenate([X,np.asanyarray([[100,1],[101,1],[102,1],[103,1],[104,1]])])\n",
    "Y=np.concatenate([Y,np.asanyarray([[3000],[3300],[3600],[3800],[3900]])])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAH6hJREFUeJzt3X+UVeV97/H3h2HAQauDZjQ6YCFKNf6oYOYq0asxxgiSNJA0aU17E1bqvfSHWTU/FrnYrtb8MI2JbbVZ13qXjUaSm2isMUhSG0LRJE2zVAZBEJEwEZQZVMbCECsEYfjeP/Zz4DCcmTlnOMw5zP681jrrnP2cZ+/9bDbzfPd5nmc/WxGBmZnlz6haF8DMzGrDAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcmp0rQswkDe96U0xadKkWhfDzOyosmLFilcjomWwfHUdACZNmkR7e3uti2FmdlSR9EI5+cpuApLUIGmlpB+k5cmSnpC0QdJ3JI1J6WPTckf6flLRNm5M6eslzajskMzMrJoq6QO4AVhXtPxl4LaImAJsB65L6dcB2yPiTOC2lA9J5wDXAucCM4F/lNRweMU3M7OhKisASJoAvAf4WloWcCXwYMqyEJiTPs9Oy6Tv35Xyzwbuj4jdEbER6AAuqsZBmJlZ5cr9BXA78BlgX1o+CeiJiL1puRNoTZ9bgc0A6fsdKf/+9BLr7CdpnqR2Se3d3d0VHIqZmVVi0AAg6b3A1ohYUZxcImsM8t1A6xxIiLgrItoioq2lZdBObDMzG6JyRgFdCrxP0izgGOB4sl8EzZJGp6v8CcCWlL8TmAh0ShoNnABsK0ovKF7HzMyARSu7uHXJerb07OK05ibmzziLOdMOaSypikF/AUTEjRExISImkXXiPhoRfwg8BnwwZZsLPJw+L07LpO8fjeyxY4uBa9MoocnAFODJqh2JmdlRbtHKLm58aA1dPbsIoKtnFzc+tIZFK7uOyP4O507g/w18SlIHWRv/3Sn9buCklP4pYAFARKwFHgCeBX4IXB8RvYexfzOzEeXWJevZtefganHXnl5uXbL+iOyvohvBIuLHwI/T5+cpMYonIn4NfKif9b8IfLHSQpqZ5cGWnl0VpR8uzwVkZlYnTmtuqij9cNX1VBBmZnlQ6Pjt6tmFOHh4ZFNjA/NnnHVE9usAYGZWQ4WO30Lbf2HMfACtR3gUkAOAmVkNler4LVT+/7HgyiO6b/cBmJnV0HB3/BZzADAzq6Hh7vgt5gBgZlZD82ecRVPjwRMjH8mO32LuAzAzq6FCB+9wTf9QzAHAzKzG5kxrHZYKvy83AZmZ5ZQDgJlZTrkJyMysBoZz2uf+OACYmQ2zvnf/FqZ9BoY1CLgJyMxsmA33tM/9cQAwMxtmtbz7t1g5zwQ+RtKTkp6WtFbS51L6vZI2SlqVXlNTuiR9VVKHpNWSLiza1lxJG9Jrbn/7NDMbyWp592+xcn4B7AaujIgLgKnATEnT03fzI2Jqeq1KadeQPe5xCjAPuBNA0onATcDFZA+SuUnS+OodipnZ0aGWd/8WK+eZwBER/5UWG9MrBlhlNvCNtN7jZA+PPxWYASyNiG0RsR1YCsw8vOKbmR195kxr5UsfOJ/W5iZENvPnlz5wfn2OApLUAKwAzgTuiIgnJP0p8EVJfw0sAxZExG6gFdhctHpnSusv3cwsd2p192+xsjqBI6I3IqYCE4CLJJ0H3AicDfw34ESyh8RD9iyDQzYxQPpBJM2T1C6pvbu7u5zimZnZEFQ0CigiesgeCj8zIl5KzTy7ga9z4AHxncDEotUmAFsGSO+7j7sioi0i2lpaWiopnpmZVaCcUUAtkprT5ybgKuC51K6PJAFzgGfSKouBj6bRQNOBHRHxErAEuFrS+NT5e3VKMzOzGiinD+BUYGHqBxgFPBARP5D0qKQW
sqadVcCfpPyPALOADmAn8DGAiNgm6QvA8pTv8xGxrXqHYmZmlVDEQAN6aqutrS3a29trXQwzs6OKpBUR0TZYPt8JbGaWUw4AZmY55dlAzcyGST1MAV3MAcDMbBjUyxTQxdwEZGY2DOplCuhiDgBmZsOgXqaALuYAYGY2DOplCuhiDgBmZsOgXqaALuZOYDOzYVDo6PUoIDOzHKqHKaCLuQnIzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyqpxHQh4j6UlJT0taK+lzKX2ypCckbZD0HUljUvrYtNyRvp9UtK0bU/p6STOO1EGZmdWLRSu7uPSWR5m84F+49JZHWbSyq9ZF2q+cXwC7gSsj4gJgKjAzPev3y8BtETEF2A5cl/JfB2yPiDOB21I+JJ0DXAucC8wE/jE9ZtLMbEQqzADa1bOL4MAMoPUSBAYNAJH5r7TYmF4BXAk8mNIXkj0YHmB2WiZ9/6704PjZwP0RsTsiNpI9M/iiqhyFmVkdqscZQIuV1QcgqUHSKmArsBT4JdATEXtTlk6gcHtbK7AZIH2/AzipOL3EOsX7miepXVJ7d3d35UdkZlYn6nEG0GJlBYCI6I2IqcAEsqv2t5bKlt7Vz3f9pffd110R0RYRbS0tLeUUz8ysLtXjDKDFKhoFFBE9wI+B6UCzpMJcQhOALelzJzARIH1/ArCtOL3EOmZmI049zgBarJxRQC2SmtPnJuAqYB3wGPDBlG0u8HD6vDgtk75/NCIipV+bRglNBqYAT1brQMzM6s2caa186QPn09rchIDW5ia+9IHz62ZCuHJmAz0VWJhG7IwCHoiIH0h6Frhf0s3ASuDulP9u4JuSOsiu/K8FiIi1kh4AngX2AtdHRC9mZiNYvc0AWkzZxXl9amtri/b29loXw8zsqCJpRUS0DZbPdwKbmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5VQ5TwSbKOkxSeskrZV0Q0r/rKQuSavSa1bROjdK6pC0XtKMovSZKa1D0oIjc0hmZlaOcp4Ithf4dEQ8Jek3gBWSlqbvbouIvy3OLOkcsqeAnQucBvybpN9KX98BvJvs+cDLJS2OiGercSBmZlaZQQNARLwEvJQ+vyZpHTDQ881mA/dHxG5gY3o05EXpu46IeB5A0v0prwOAmVkNVNQHIGkSMA14IiV9XNJqSfdIGp/SWoHNRat1prT+0s3MrAbKDgCSjgO+C3wiIn4F3AmcAUwl+4Xwd4WsJVaPAdL77meepHZJ7d3d3eUWz8zMKlRWAJDUSFb5fysiHgKIiFciojci9gH/xIFmnk5gYtHqE4AtA6QfJCLuioi2iGhraWmp9HjMzKxM5YwCEnA3sC4i/r4o/dSibO8HnkmfFwPXShoraTIwBXgSWA5MkTRZ0hiyjuLF1TkMMzOrVDmjgC4FPgKskbQqpf0F8GFJU8macTYBfwwQEWslPUDWubsXuD4iegEkfRxYAjQA90TE2ioei5mZVUARhzTD1422trZob2+vdTHMzI4qklZERNtg+XwnsJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeWUA4CZWU45AJiZ5ZQDgJlZTjkAmJnllAOAmVlOOQCYmeVUOY+EnCjpMUnrJK2VdENKP1HSUkkb0vv4lC5JX5XUIWm1pAuLtjU35d8gae6ROywzMxtMOb8A9gKfjoi3AtOB6yWdAywAlkXEFGBZWga4huw5wFOAecCdkAUM4CbgYrIHyN9UCBpmZjb8Bg0AEfFSRDyVPr8GrANagdnAwpRtITAnfZ4NfCMyjwPN6QHyM4ClEbEtIrYDS4GZVT0aMzMrW0V9AJImAdOAJ4BTIuIlyIIEcHLK1gpsLlqtM6X1l953H/MktUtq7+7urqR4
ZmZWgbIDgKTjgO8Cn4iIXw2UtURaDJB+cELEXRHRFhFtLS0t5RbPzMwqVFYAkNRIVvl/KyIeSsmvpKYd0vvWlN4JTCxafQKwZYB0MzOrgXJGAQm4G1gXEX9f9NVioDCSZy7wcFH6R9NooOnAjtREtAS4WtL41Pl7dUozM7MaGF1GnkuBjwBrJK1KaX8B3AI8IOk64EXgQ+m7R4BZQAewE/gYQERsk/QFYHnK9/mI2FaVozAzs4op4pBm+LrR1tYW7e3ttS6GmdlRRdKKiGgbLJ/vBDYzyykHADOznHIAMDPLKQcAM7OccgAwM8spBwAzs5xyADAzyykHADOznHIAMDPLKQcAM7OccgAwM8spBwAzs5xyADAzyykHADOznHIAMDPLqXKeCHaPpK2SnilK+6ykLkmr0mtW0Xc3SuqQtF7SjKL0mSmtQ9KC6h+KmZlVopxfAPcCM0uk3xYRU9PrEQBJ5wDXAuemdf5RUoOkBuAO4BrgHODDKa+ZmdXIoI+EjIifSppU5vZmA/dHxG5go6QO4KL0XUdEPA8g6f6U99mKS2xmZlVxOH0AH5e0OjURjU9prcDmojydKa2/dDMzq5GhBoA7gTOAqcBLwN+ldJXIGwOkH0LSPEntktq7u7uHWDwzMxvMkAJARLwSEb0RsQ/4Jw4083QCE4uyTgC2DJBeatt3RURbRLS1tLQMpXhmZlaGIQUASacWLb4fKIwQWgxcK2mspMnAFOBJYDkwRdJkSWPIOooXD73YZmZ2uAbtBJZ0H3AF8CZJncBNwBWSppI142wC/hggItZKeoCsc3cvcH1E9KbtfBxYAjQA90TE2qofjZmZlU0RJZvi60JbW1u0t7fXuhhmZkcVSSsiom2wfL4T2MwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHJq0AAg6R5JWyU9U5R2oqSlkjak9/EpXZK+KqlD0mpJFxatMzfl3yBp7pE5HDMzK1c5vwDuBWb2SVsALIuIKcCytAxwDdlzgKcA84A7IQsYZI+SvJjsAfI3FYKGmZnVxqABICJ+CmzrkzwbWJg+LwTmFKV/IzKPA83pAfIzgKURsS0itgNLOTSomJnZMBpqH8ApEfESQHo/OaW3ApuL8nWmtP7SDyFpnqR2Se3d3d1DLJ6ZmQ2m2p3AKpEWA6QfmhhxV0S0RURbS0tLVQtnZmYHDDUAvJKadkjvW1N6JzCxKN8EYMsA6WZmViNDDQCLgcJInrnAw0XpH02jgaYDO1IT0RLgaknjU+fv1SnNzMxqZPRgGSTdB1wBvElSJ9lonluAByRdB7wIfChlfwSYBXQAO4GPAUTENklfAJanfJ+PiL4dy2ZmNowUUbIpvi60tbVFe3t7rYthZnZUkbQiItoGy+c7gc3McsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxy6rACgKRNktZIWiWpPaWdKGmppA3pfXxKl6SvSuqQtFrShdU4ADMzG5pq/AJ4Z0RMLXr4wAJgWURMAZalZYBrgCnpNQ+4swr7NjOzIToSTUCzgYXp80JgTlH6NyLzONBceLC8mZkNv8MNAAH8SNIKSfNS2inpQfCk95NTeiuwuWjdzpRmZmY1MOhD4QdxaURskXQysFTScwPkVYm0Qx5InALJPIDTTz/9MItnZmb9OaxfABGxJb1vBb4HXAS8UmjaSe9bU/ZOYGLR6hOALSW2eVdEtEVEW0tLy+EUz8zMBjDkXwCSjgVGRcRr6fPVwOeBxcBc4Jb0/nBaZTHwcUn3AxcDOwpNRWZmudPbC93d8Mor2evllw/+PGkS3HzzES3C4TQBnQJ8T1JhO9+OiB9KWg48IOk64EXgQyn/I8As
oAPYCXzsMPZtZlZ/Xn8d7rkH/vzPB8538snw6quwb9+h3zU1wZvfDGPGHJkyFhlyAIiI54ELSqT/J/CuEukBXD/U/ZmZ1cTu3fDNb8Ktt8IvflGdbc6enVXyb34znHJK9ip8Pu44UKku0+o73E5gM7OjT28vLF+eXakvX37k9nPzzTBnDpxzzrBV6pVwADCzkWHfPli1CrZuhWOOydrS/+qvYMOG6u5n+nSYPz+7im9oqO62h5kDgJnVrwj413+Fr3wFfvKTI7efqVPhM5+BD34QGhsHzLpoZRe3LlnPlr/8ISc0NSJBz849nNbcxDvPbuGx57rp6tlFg0RvBK3NTcyfcRbtL2zjvic20xtBg8SHL57IzXPOP3LHVAZlTfP1qa2tLdrb22tdDDOrpoisMv/KV7LKvdrOOw9uv51l28XtK7exbs8YfuO4Y4iAHbsOrqi39OzitFRBz5l26H2phcq+uEIXJW5gGsQooER3L/9j+ulHJAhIWlE0PU//+RwAzKxc+69+S1ScP/r+z3n6nn/mt9c+zuWbVtG059dV3//rp03k/7bN4e4z3sH4k8cftP/isp3Q1Mjrb+xlT2959VtTYwNf+sD5BwWBRSu7uPGhNeza01v14yhokPjll2ZVfbvlBgA3AZlZSYtWdnH/XQ/z9tX/zpn/uZnLN61kx/lXccPunfzemn/LMt14IP/V6VWpX5/Uwv952/u595yr+K+x4/an962UF63sYv6DT++v1Hf27GL+g0/vz19cWffs2lNRGXbt6eXWJesPCgC3Lll/RCt/gN4aX4A7AJiNAANdmff1k3sXc+an/4TWbQPfhzmHAzM5FsxtXzxoWTpOnMBPJ1/ITydP48mJ57FzTFO/ecePa+TXe/aVrGj7Vsqf+/7aQ67o9/QGn/v+WsaNGX3YlfWWnl0DLh8JDTUeGeQAYFZjlVTepdbr6tnF6T0v8ydPPsRHVj6SfXlj/+u9YwhlfPGEU/jJW97GvRf+Di+MP5W9DdWpOrbvHPhKvbgS7i/v9p176BlkO+U4rbnpkOWuKgWB/voAPnzxxBKpw8cBwKxIuZVxJZX2opVdfO77aw+qwMaPa+Sm3zkXOLjpoqtnF5/8ziraX9jGzZe3wrJl8KMfZa8XXzxou6Wu0Idir0bx/o/8HWtOnVKFrVVX30p5oHyHU1k3NTYwf8ZZB6XNn3FWRX0Ax45pYNcbvYdU9IVzXY+jgBwALNf6dhy+tnsvvfuyZoauPm3Mhavtvorz9Q0Cfdutj929k0teXM1lG1cy7fan+M3tL1WlEu/r2ZMnM/+aG9jylrey8q+zlvlLb3m0ale01dDU2MDY0aP6ba/vWyk3NzWWzNvc1Fiysu7vqnv8uEbe89unDjoKqLDcdxRQ6wCjiAa6MJgzrbXmFX5fHgVkdaVvhfzG3l527sn+jAtXUpU0jxT+EIv/YPtudzDHjmlgX9DvlWDTG79m+uY1vHvz0/xBzzpYv778Ay7HSSfBjBlw9dVw1VXQ2srkBf9S9lDETbe8B6CidcohwWiJPfsq32p/v4D6ft93VM78f376oP01jhK3fuiCfitfYEjNa0c7DwO1qio1Hrq1Cn9QxdstV3NTI++9oP8ruGoM3xu79w3aOp/lso1P8Y6NT/HW7k1D3lYpO8Yey79Pmsa/T85eW44/+ZA8AjamyruvSq7mCwGgmr8ACiN04NAKtm/aYGPuK21Oy2OFXikHABvUYH9M5VbOzU2NfPZ9/V+ZL1rZxWcXr93/832wq7+hKB4y2F9Fd8yeX/N7q5cy78nvMeFX2WMqvv6232F3QyMtr2+n5fUeWn/VzRnbOodUhp2NYxk3a+aBq/Uzzhi00m1uamTHrj0lr8xbm5v4jwVXllyv3CDX3NTIqpuurmidUhpHwXHHNO6/49UVb31zADgKlVshl3tlNdA+unp2lbyjsVCZt7+wjW89/mLZTQalbqQp7K/vz3aAxgZx3NjRg44CGcyYvXuY/exj
/PETD3HmECvu3Q2NdB/bzKvHjufVcSdw1S8PTA72xqjR2VX6pGn8bNJUOk6a2O+kXsWVbUHfPoBijQ3i1g9eUPLfur9/z77bLvx/aB7XyI6dew5q8y5uHum7Tn9t2tX+hWe14QBQR8ppm3zn2S18Z/nmQyqKxlFw64emAoNfLRcqlIGuxMu5AmxsUNl3UBYrdcU6pGaHCM5/uYMrnm/nsk0ruajz2YrLUo4vvPM6Hjz/KnYc0//0u42jxO9fNJHvruga+N++RGVbMNAooFJ3sQ71CtvNI1bgAFAlff+oJp3UxOPPbz/oDr7WEqMABpo7pHGUQJRdyY4SnNDUWNbV8vhxjftHffR1pEeBlGqzLnQ8KvZxztaNXLZxJZdteopLX1h9ZAoxfTpP/O4f8UfbT+P1Cls6BhodMlBzmK+Urd7UbQCQNBP4B6AB+FpE3NJf3qEGgP7anAe6Mv6Lh1bvHxUiwR9enD2QvtxmkKbGBn73ba2DXikOh039dBxWexQIEZz16gtctvEpLt+4kss3razKZv/1ty7hml/8HID2Cefyk0lTWdH6Vp6ceF6/NyCVmjag3FFAlYwuMjsa1OVcQJIagDuAd5M9JH65pMURUbXf+KXanLfv3DPgOO1PPbCK4ibqCPh/jx98081gdu3p3X+TR70q62aZCCZv38LlG5/iso1P0fJ6Dxe8XN351FeeehY/O+NCLvij32P7uVP5yqPPZ30Syv7t4eCO5Tay8/RCBf0fc6a1ukI3G8Rw3wh2EdCRHidJekD8bKBqAeDWJetLjkve0xuHTPZUyD+EYcwlHcnKv7mpkd17S8+Z0jdfSRs3csdrT9K96Ptc+vwKxu3ZXdXy9Zx9Hs1z3puNfrnkEhg7dv93/f0iuzydi9kXTRp0+67QzapvuANAK7C5aLkTuLg4g6R5wDyA008/veIdDDSBU6nvqjnhU6HNfzBD6QP47PuyYZPFoz5O3bSe//XEQ5z5n5s5vedlTtj9erbCZ0tvZ2pZeztgw5vfwtcvmMXO1olc+T8/wPvefma/eZsH2I4rb7P6NNwBoNRQi4NqwYi4C7gLsj6ASncwUDNHqXlFqjXh02B9AIWO4NaiUUB9R4YAnPqrbj7xs2/z+2uWHryB1FNy2NMGnH32gXHql1+ePYC6H1OAvznc/ZlZ3RruANAJFE9/NwHYUs0dzJ9xVr/jzvtO9lTI37cPoKBhlPbPCwNZJX7mycfyfPfOfkcBtf3miYeMAjp39G6+/NoKzvtqqk7TTI1VnQPmz/4sq9Tf+U44/vhqbtnMRqjhDgDLgSmSJgNdwLXAH1RzB4WmhnJHARXSSo0CKlTmJTsat2+HhQvhk5/Mlosq9apV7G9/e7b9K6/M5oMxM6uiWgwDnQXcTjYM9J6I+GJ/eYf9PoBt2+BnP8um3l2yBDo6svRLLoGf/7x6+/nTP4VPfQrO7L9N3cxsqOpyGChARDwCPDJsO3z9dfjxjw/Mqf7cc5Vvo5zK/2Mfyyr1886rfPtmZjUwMp8HsG0bXHEFrFkztPVPOCFrT58xA979bhjCaCQzs3o3MgPA8cfDGWdkV+P33ZeljRuXVeqF1xln1LaMZmY1NjIDwOjR8L3vZZ+//e3alsXMrE6NqnUBzMysNhwAzMxyygHAzCynHADMzHLKAcDMLKccAMzMcsoBwMwspxwAzMxyqq4fCi+pG3jhMDfzJuDVKhSn3uXlOCE/x+rjHFmG8zh/MyJaBstU1wGgGiS1lzMr3tEuL8cJ+TlWH+fIUo/H6SYgM7OccgAwM8upPASAu2pdgGGSl+OE/Byrj3NkqbvjHPF9AGZmVloefgGYmVkJIzoASJopab2kDkkLal2eapE0UdJjktZJWivphpR+oqSlkjak9/G1Lms1SGqQtFLSD9LyZElPpOP8jqQxtS7j4ZLULOlBSc+l8/r2EXw+P5n+3z4j6T5Jx4yEcyrpHklbJT1TlFbyHCrz1VQ3rZZ0YS3KPGIDgKQG
4A7gGuAc4MOSzqltqapmL/DpiHgrMB24Ph3bAmBZREwBlqXlkeAGYF3R8peB29Jxbgeuq0mpqusfgB9GxNnABWTHO+LOp6RW4M+Btog4D2gArmVknNN7gZl90vo7h9cAU9JrHnDnMJXxICM2AAAXAR0R8XxEvAHcD8yucZmqIiJeioin0ufXyCqLVrLjW5iyLQTm1KaE1SNpAvAe4GtpWcCVwIMpy1F/nJKOBy4H7gaIiDcioocReD6T0UCTpNHAOOAlRsA5jYifAtv6JPd3DmcD34jM40CzpFOHp6QHjOQA0ApsLlruTGkjiqRJwDTgCeCUiHgJsiABnFy7klXN7cBngH1p+SSgJyL2puWRcF7fAnQDX09NXV+TdCwj8HxGRBfwt8CLZBX/DmAFI++cFvR3DuuifhrJAUAl0kbUkCdJxwHfBT4REb+qdXmqTdJ7ga0RsaI4uUTWo/28jgYuBO6MiGnA64yA5p5SUhv4bGAycBpwLFlzSF9H+zkdTF38Px7JAaATmFi0PAHYUqOyVJ2kRrLK/1sR8VBKfqXwMzK9b61V+arkUuB9kjaRNeFdSfaLoDk1H8DIOK+dQGdEPJGWHyQLCCPtfAJcBWyMiO6I2AM8BFzCyDunBf2dw7qon0ZyAFgOTEmjC8aQdTQtrnGZqiK1g98NrIuIvy/6ajEwN32eCzw83GWrpoi4MSImRMQksvP3aET8IfAY8MGUbSQc58vAZklnpaR3Ac8yws5n8iIwXdK49P+4cKwj6pwW6e8cLgY+mkYDTQd2FJqKhlVEjNgXMAv4BfBL4C9rXZ4qHtd/J/u5uBpYlV6zyNrHlwEb0vuJtS5rFY/5CuAH6fNbgCeBDuCfgbG1Ll8Vjm8q0J7O6SJg/Eg9n8DngOeAZ4BvAmNHwjkF7iPr19hDdoV/XX/nkKwJ6I5UN60hGxU17GX2ncBmZjk1kpuAzMxsAA4AZmY55QBgZpZTDgBmZjnlAGBmllMOAGZmOeUAYGaWUw4AZmY59f8BXALME0CKeFoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x19e43c1e828>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "lr=LinearRegression()\n",
    "alpha,beta=lr.fit(X[:,:-1],Y)\n",
    "lr.plot_fit_boundary(X[:,:-1],Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以发现成功省去了调节超参`alpha`的烦恼，本节由于将`alpha`也视作随机变量，并假设它来源于一个Gamma分布，最后通过求期望的方式得到一个更优的`alpha`，如下，可以发现最终迭代求解的较优的L2正则化系数为104，当然我们也可以将$\\beta$假设为一个Gamma分布一并放到变分框架中求解，这样可以免去对`beta`的设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "104.99882616406987"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alpha/beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
